"""
Phase 7: Robustness Tests
===========================
1. Leave-one-country-out jackknife on continuous rating model
2. Alternative safe threshold: AA (19) instead of AA- (18) -- safe-observation and country counts only (no re-estimation)
3. Subsample: pre/post-GFC; OECD vs non-OECD rated
4. Placebo: shuffled demographics (permutation test)

Output: table11_robustness.md, phase7_results.csv
"""

import sys
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
# Silence every warning category for the whole run so console output stays readable.
# NOTE(review): this also hides numerical RuntimeWarnings raised during estimation;
# consider narrowing to specific categories/modules.
warnings.simplefilter('ignore')

# Project layout: inputs are read from data/processed, tables go to output/tables.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_asset_cliff")
ROOT_DIR = PROJECT_DIR.parent
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# Put the sibling "multilateral" project's src/ first on the import path so the
# shared PanelGLS estimator can be imported below.
sys.path.insert(0, str(ROOT_DIR.joinpath("multilateral", "src")))
from model import PanelGLS

# Current OECD membership (38 countries; complete since Costa Rica joined in 2021),
# as ISO3 codes. Membership is treated as a fixed country attribute, not by
# accession year. Greece (founding member) and Slovakia, Slovenia, Estonia,
# Latvia, Lithuania, Colombia and Costa Rica must be listed here, otherwise they
# fall into the "Non-OECD rated" subsample.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
}


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise '' (including
    for NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def write_markdown_table(path, title, headers, rows, notes=None):
    """Render a Markdown table to *path* and print where it was saved.

    Parameters
    ----------
    path : pathlib.Path
        Destination file. Missing parent directories are created.
    title : str
        Emitted as a level-3 heading above the table.
    headers : sequence of str
        Column headers; the first column is left-aligned, the rest right-aligned.
    rows : iterable of sequences
        Table body; each cell is converted with ``str``.
    notes : str, optional
        Italicised note appended below the table.
    """
    def cell(value):
        # A literal '|' would otherwise be read as a column separator.
        return str(value).replace("|", "\\|")

    lines = [f"### {title}", ""]
    lines.append("| " + " | ".join(cell(h) for h in headers) + " |")
    lines.append("|" + "|".join(["--:" if i > 0 else ":--" for i in range(len(headers))]) + "|")
    for row in rows:
        lines.append("| " + " | ".join(cell(c) for c in row) + " |")
    if notes:
        lines.append("")
        lines.append(f"*{notes}*")
    lines.append("")
    # Output directories may not exist on a fresh checkout.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Saved: {path}")


def fit_rating_model(df, x_vars, label):
    """Regress ``rating_numeric`` on *x_vars* with PanelGLS.

    ``iso3`` is passed as the panel entity and ``year`` as the time index.
    Rows with a missing value in any of the used columns are dropped first.

    Returns ``(None, None)`` when fewer than 50 complete rows remain; otherwise
    ``(model, rdf)``, where *rdf* is ``model.to_dataframe(feature_names=x_vars)``
    tagged with *label* and the model's n_obs, n_countries, r_squared and rho.
    """
    requested = ['rating_numeric'] + x_vars + ['iso3', 'year']
    available = [name for name in requested if name in df.columns]
    panel = df[available].dropna()
    if len(panel) < 50:  # too few complete rows to estimate
        return None, None

    fitted = PanelGLS()
    fitted.fit(
        panel['rating_numeric'].values,
        panel[x_vars].values,
        panel['iso3'].values,
        panel['year'].values,
    )

    coef_table = fitted.to_dataframe(feature_names=x_vars)
    coef_table['model'] = label
    for stat_name in ('n_obs', 'n_countries', 'r_squared', 'rho'):
        coef_table[stat_name] = getattr(fitted, stat_name)
    return fitted, coef_table


def _print_section(title):
    """Print *title* between two 70-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _estimation_sample(df, x_vars):
    """Return the complete-case rows fit_rating_model estimates on."""
    return df.dropna(subset=['rating_numeric'] + x_vars + ['iso3', 'year'])


def _permutation_pvalue(permuted, observed):
    """Two-sided Monte Carlo permutation p-value with the +1 correction.

    Treating the observed statistic as one more draw gives
    ``(1 + #{|perm| >= |obs|}) / (1 + n_perm)``. This is never exactly zero and
    stays valid when only a random subset of permutations is drawn.
    """
    permuted = np.asarray(permuted, dtype=float)
    n_extreme = np.sum(np.abs(permuted) >= np.abs(observed))
    return (1 + n_extreme) / (1 + permuted.size)


def _run_jackknife(df, x_vars):
    """Section 1: leave-one-country-out jackknife of the OADR and spline terms.

    Returns ``(result_frames, table_rows)``. Both are empty when the
    full-sample model cannot be fitted.
    """
    _print_section("1. LEAVE-ONE-COUNTRY-OUT JACKKNIFE")
    i_oadr = x_vars.index('old_dep')
    i_spline = x_vars.index('oadr_spline_20')

    m_full, r_full = fit_rating_model(df, x_vars, "Full sample")
    if m_full is None:
        return [], []
    full_oadr = m_full.beta[i_oadr]
    full_spline = m_full.beta[i_spline]

    # Leave out only countries that actually enter the estimation sample.
    # Dropping a country with no complete rows reproduces the full-sample
    # estimate, which artificially shrinks the jackknife spread and inflates N.
    # The complete-case sample also excludes NaN iso3 codes, which would make
    # sorted() fail on a str/float comparison.
    countries = sorted(_estimation_sample(df, x_vars)['iso3'].unique())

    jack_oadr, jack_spline = [], []
    for iso3 in countries:
        m_jack, _ = fit_rating_model(df[df['iso3'] != iso3], x_vars, f"LOO-{iso3}")
        if m_jack is not None:
            jack_oadr.append(m_jack.beta[i_oadr])
            jack_spline.append(m_jack.beta[i_spline])

    if not jack_oadr:
        return [r_full], []

    jack_arr = np.array(jack_oadr)
    spline_arr = np.array(jack_spline)
    print("\n  OADR coefficient jackknife:")
    print(f"    Full sample: {full_oadr:.3f}")
    print(f"    Mean LOO:    {jack_arr.mean():.3f}")
    print(f"    SD LOO:      {jack_arr.std():.3f}")
    print(f"    Range:       [{jack_arr.min():.3f}, {jack_arr.max():.3f}]")
    print("\n  Spline(20%) coefficient jackknife:")
    print(f"    Full sample: {full_spline:.3f}")
    print(f"    Mean LOO:    {spline_arr.mean():.3f}")
    print(f"    Range:       [{spline_arr.min():.3f}, {spline_arr.max():.3f}]")

    rows = [
        ["Jackknife OADR", f"{full_oadr:.2f}",
         f"[{jack_arr.min():.2f}, {jack_arr.max():.2f}]",
         f"{jack_arr.std():.2f}", str(len(jack_oadr))],
        ["Jackknife Spline(20%)", f"{full_spline:.2f}",
         f"[{spline_arr.min():.2f}, {spline_arr.max():.2f}]",
         f"{spline_arr.std():.2f}", str(len(jack_spline))],
    ]
    return [r_full], rows


def _count_safe_thresholds(df):
    """Section 2: count safe observations/countries under the AA- and AA cutoffs.

    Only descriptive counts are produced; no model is re-estimated.
    Returns table rows.
    """
    _print_section("2. ALTERNATIVE SAFE THRESHOLD: AA (19)")
    rows = []
    for threshold, label in [(18, "AA- threshold (baseline)"), (19, "AA threshold")]:
        is_safe = df['rating_numeric'] >= threshold  # NaN ratings count as not safe
        n_safe = int(is_safe.sum())
        n_countries = df.loc[is_safe, 'iso3'].nunique()
        print(f"  {label}: {n_safe} safe obs, {n_countries} countries ever safe")
        rows.append([label, f"N_safe={n_safe}", f"C={n_countries}", "-", "-"])
    return rows


def _run_subsamples(df, x_vars):
    """Section 3: re-estimate on pre/post-GFC and OECD/non-OECD subsamples.

    Returns ``(result_frames, table_rows)``.
    """
    _print_section("3. SUBSAMPLE TESTS")
    i_oadr = x_vars.index('old_dep')
    i_spline = x_vars.index('oadr_spline_20')

    # 2008-2009 fall in neither time window (GFC years are excluded).
    subsamples = [
        ("Pre-GFC (1990-2007)", df[(df['year'] >= 1990) & (df['year'] <= 2007)]),
        ("Post-GFC (2010-2024)", df[(df['year'] >= 2010) & (df['year'] <= 2024)]),
        ("OECD rated", df[df['iso3'].isin(OECD)]),
        ("Non-OECD rated", df[~df['iso3'].isin(OECD)]),
    ]

    results, rows = [], []
    for label, sub in subsamples:
        m, r = fit_rating_model(sub, x_vars, label)
        if m is None:
            continue
        results.append(r)
        c_oadr = m.beta[i_oadr]
        p_oadr = m.pvalues[i_oadr]
        c_spline = m.beta[i_spline]
        p_spline = m.pvalues[i_spline]
        print(f"\n  {label} (N={m.n_obs}, {m.n_countries} countries, R²={m.r_squared:.4f})")
        print(f"    OADR: {c_oadr:.2f}{stars(p_oadr)}  "
              f"Spline(20%): {c_spline:.2f}{stars(p_spline)}")
        # NOTE(review): these rows reuse the table's generic columns. The spline
        # coefficient lands under "Range / p-value", N under "SD / Detail" and
        # R² under "N".
        rows.append([label, f"{c_oadr:.2f}{stars(p_oadr)}",
                     f"{c_spline:.2f}{stars(p_spline)}",
                     f"{m.n_obs:,}", f"{m.r_squared:.3f}"])
    return results, rows


def _run_placebo(df, x_vars, n_perm=200, seed=123):
    """Section 4: permutation placebo for the demographic coefficients.

    Both old_dep and oadr_spline_20 are shuffled with the same permutation
    across all rows of the complete-case estimation sample, so their pairing
    within an observation is kept. The model is refitted for each of *n_perm*
    draws. Returns table rows.
    """
    _print_section("4. PLACEBO: SHUFFLED DEMOGRAPHICS (permutation test)")
    i_oadr = x_vars.index('old_dep')
    i_spline = x_vars.index('oadr_spline_20')

    # Same complete-case sample as fit_rating_model, so the "true" coefficient
    # below matches the full-sample estimate.
    est = _estimation_sample(df, x_vars)
    if len(est) < 50:
        print(f"  Skipped: only {len(est)} complete observations")
        return []

    y = est['rating_numeric'].values
    X = est[x_vars].values
    entity = est['iso3'].values
    time = est['year'].values

    m_true = PanelGLS()
    m_true.fit(y, X, entity, time)
    true_oadr = m_true.beta[i_oadr]
    true_spline = m_true.beta[i_spline]

    # Local generator: produces the same permutations as np.random.seed(seed)
    # without mutating NumPy's global random state.
    rng = np.random.RandomState(seed)
    perm_oadr, perm_spline, failures = [], [], []
    for _ in range(n_perm):
        idx = rng.permutation(len(X))
        X_perm = X.copy()
        X_perm[:, i_oadr] = X[idx, i_oadr]
        X_perm[:, i_spline] = X[idx, i_spline]
        try:
            m_perm = PanelGLS()
            m_perm.fit(y, X_perm, entity, time)
        except Exception as exc:  # best-effort: PanelGLS failure modes are opaque here
            failures.append(repr(exc))
            continue
        perm_oadr.append(m_perm.beta[i_oadr])
        perm_spline.append(m_perm.beta[i_spline])

    if failures:
        print(f"  WARNING: {len(failures)}/{n_perm} permutation fits failed and were "
              f"skipped (first: {failures[0]})")
    if not perm_oadr:
        return []

    perm_oadr = np.array(perm_oadr)
    perm_spline = np.array(perm_spline)
    n_valid = len(perm_oadr)

    p_perm_oadr = _permutation_pvalue(perm_oadr, true_oadr)
    p_perm_spline = _permutation_pvalue(perm_spline, true_spline)

    print(f"  True OADR: {true_oadr:.3f}, Permutation p={p_perm_oadr:.3f}")
    print(f"  True Spline: {true_spline:.3f}, Permutation p={p_perm_spline:.3f}")
    print(f"  Permuted OADR: mean={perm_oadr.mean():.3f}, "
          f"SD={perm_oadr.std():.3f}")
    print(f"  Permuted Spline: mean={perm_spline.mean():.3f}, "
          f"SD={perm_spline.std():.3f}")

    # Report the number of permutations that actually succeeded.
    return [
        ["Placebo OADR", f"{true_oadr:.2f}",
         f"perm_p={p_perm_oadr:.3f}",
         f"perm_mean={perm_oadr.mean():.2f}",
         f"N_perm={n_valid}"],
        ["Placebo Spline(20%)", f"{true_spline:.2f}",
         f"perm_p={p_perm_spline:.3f}",
         f"perm_mean={perm_spline.mean():.2f}",
         f"N_perm={n_valid}"],
    ]


def main():
    """Run all Phase 7 robustness checks and write the outputs.

    Reads ``cliff_panel.csv`` from PROCESSED_DIR. Writes
    ``table11_robustness.md`` and ``phase7_results.csv`` into TABLES_DIR,
    creating it if needed.

    Raises
    ------
    ValueError
        If the panel lacks ``old_dep`` or ``oadr_spline_20``; every section
        reports those two coefficients.
    """
    print("=" * 70)
    print("PHASE 7: Robustness Tests")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "cliff_panel.csv")
    print(f"Loaded: {df['iso3'].nunique()} countries, {len(df):,} obs")

    x_vars = ['old_dep', 'oadr_spline_20', 'rgdp_growth', 'inflation', 'kaopen']
    x_vars = [v for v in x_vars if v in df.columns]
    missing = [v for v in ('old_dep', 'oadr_spline_20') if v not in x_vars]
    if missing:
        raise ValueError(f"cliff_panel.csv is missing required columns: {missing}")

    n_perm = 200
    all_results = []
    table_rows = []

    jk_results, jk_rows = _run_jackknife(df, x_vars)
    all_results.extend(jk_results)
    table_rows.extend(jk_rows)

    table_rows.extend(_count_safe_thresholds(df))

    sub_results, sub_rows = _run_subsamples(df, x_vars)
    all_results.extend(sub_results)
    table_rows.extend(sub_rows)

    table_rows.extend(_run_placebo(df, x_vars, n_perm=n_perm))

    # ── Write outputs ──
    TABLES_DIR.mkdir(parents=True, exist_ok=True)

    if table_rows:
        write_markdown_table(
            TABLES_DIR / "table11_robustness.md",
            "Table 11: Robustness Tests",
            ["Test", "OADR Coef / Stat", "Range / p-value", "SD / Detail", "N"],
            table_rows,
            notes=("PanelGLS with AR(1). Dependent variable: rating_numeric (21-point scale). "
                   "Jackknife: leave-one-country-out. "
                   f"Placebo: {n_perm} permutations of demographic variables."),
        )

    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(TABLES_DIR / "phase7_results.csv", index=False)
        print(f"\n  Saved: {TABLES_DIR / 'phase7_results.csv'}")

    print("\n" + "=" * 70)
    print("Phase 7 complete.")
    print("=" * 70)


if __name__ == "__main__":
    # Run the robustness pipeline only when executed as a script, not on import.
    main()
